{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline\n",
    "\n",
    "from functools import partial\n",
    "from pprint import pprint\n",
    "import random\n",
    "from collections import deque\n",
    "import numpy as np\n",
    "from matplotlib import pyplot as plt\n",
    "\n",
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms\n",
    "from torchvision.models import resnet50\n",
    "from torch.utils.data import DataLoader\n",
    "from torch import nn, optim\n",
    "import wids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "# parameters\n",
    "epochs = 3\n",
    "max_steps = 100000\n",
    "batch_size = 32\n",
    "bucket = \"https://storage.googleapis.com/webdataset/fake-imagenet/\"\n",
    "num_workers = 4\n",
    "cache_dir = \"./_cache\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# helpers\n",
    "\n",
    "import time\n",
    "\n",
    "\n",
    "def enumerate_report(seq, delta, growth=1.0):\n",
    "    last = 0\n",
    "    count = 0\n",
    "    for count, item in enumerate(seq):\n",
    "        now = time.time()\n",
    "        if now - last > delta:\n",
    "            last = now\n",
    "            yield count, item, True\n",
    "        else:\n",
    "            yield count, item, False\n",
    "        delta *= growth"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The standard TorchVision transformations.\n",
    "\n",
    "transform_train = transforms.Compose(\n",
    "    [\n",
    "        transforms.RandomResizedCrop(224),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "transform_val = transforms.Compose(\n",
    "    [\n",
    "        transforms.Resize(256),\n",
    "        transforms.CenterCrop(224),\n",
    "        transforms.ToTensor(),\n",
    "        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# The dataset returns dictionaries. This is a small function we transform it\n",
    "# with to get the augmented image and the label.\n",
    "\n",
    "\n",
    "def make_sample(sample, val=False):\n",
    "    image = sample[\".jpg\"]\n",
    "    label = sample[\".cls\"]\n",
    "    if val:\n",
    "        return transform_val(image), label\n",
    "    else:\n",
    "        return transform_train(image), label"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# These are standard PyTorch datasets. Download is incremental into the cache.\n",
    "\n",
    "trainset = wids.ShardListDataset(\n",
    "    bucket+\"imagenet-train.json\", cache_dir=cache_dir, keep=True\n",
    ")\n",
    "valset = wids.ShardListDataset(\n",
    "    bucket+\"imagenet-val.json\", cache_dir=cache_dir, keep=True\n",
    ")\n",
    "\n",
    "trainset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Next, we add the transformation to the dataset. Transformations\n",
    "# are executed in sequence. In fact, by default, there is a transformation\n",
    "# that reads and decodes images.\n",
    "\n",
    "trainset.add_transform(make_sample)\n",
    "valset.add_transform(partial(make_sample, val=True))\n",
    "\n",
    "print(trainset[0][0].shape, trainset[0][1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We also need a sampler for the training set. There are three\n",
    "# special samplers in the `wids` package that work particularly\n",
    "# well with sharded datasets:\n",
    "# - `wids.ShardedSampler` shuffles shards and then samples in shards;\n",
    "#   it guarantees that only one shard is used at a time\n",
    "# - `wids.ChunkedSampler` samples by fixed sized chunks, shuffles\n",
    "#   the chunks, and the the samples within each chunk\n",
    "# - `wids.DistributedChunkedSampler` is like `ChunkedSampler` but\n",
    "#   works with distributed training (it first divides the entire\n",
    "#   dataset into per-node chunks, then the per-node chunks into\n",
    "#   smaller chunks, then shuffles the smaller chunks)\n",
    "\n",
    "# trainsampler = wids.ShardedSampler(trainset)\n",
    "# trainsampler = wids.ChunkedSampler(trainset, chunksize=1000, shuffle=True)\n",
    "trainsampler = wids.DistributedChunkedSampler(trainset, chunksize=1000, shuffle=True)\n",
    "\n",
    "plt.plot(list(trainsampler)[:2500])\n",
    "\n",
    "# Note that the sampler shuffles within each shard before moving on to\n",
    "# the next shard. Furthermore, on the first epoch, the sampler\n",
    "# uses the shards in order, but on subsequent epochs, it shuffles\n",
    "# them. This makes testing and debugging easier. If you don't like\n",
    "# this behavior, you can use shufflefirst=True\n",
    "\n",
    "trainsampler.set_epoch(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create data loaders for the training and validation datasets\n",
    "\n",
    "trainloader = DataLoader(trainset, batch_size=batch_size, num_workers=4, sampler=trainsampler)\n",
    "valloader = DataLoader(valset, batch_size=batch_size, shuffle=False, num_workers=4)\n",
    "\n",
    "images, classes = next(iter(trainloader))\n",
    "print(images.shape, classes.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The usual PyTorch model definition. We use an uninitialized ResNet50 model.\n",
    "\n",
    "model = resnet50(pretrained=False)\n",
    "\n",
    "# Define the loss function and optimizer\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)\n",
    "\n",
    "# Move the model to the GPU if available\n",
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "model = model.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "losses, accuracies = deque(maxlen=100), deque(maxlen=100)\n",
    "\n",
    "steps = 0\n",
    "\n",
    "# Train the model\n",
    "for epoch in range(epochs):\n",
    "    for i, data, verbose in enumerate_report(trainloader, 5):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data[0].to(device), data[1].to(device)\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = model(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        pred = outputs.cpu().detach().argmax(dim=1, keepdim=True)\n",
    "        correct = pred.eq(labels.cpu().view_as(pred)).sum().item()\n",
    "        accuracy = correct / float(len(labels))\n",
    "\n",
    "        losses.append(loss.item())\n",
    "        accuracies.append(accuracy)\n",
    "        steps += len(labels)\n",
    "\n",
    "        if verbose and len(losses) > 5:\n",
    "            print(\n",
    "                \"[%d, %5d] loss: %.5f correct: %.5f\"\n",
    "                % (epoch + 1, i + 1, np.mean(losses), np.mean(accuracies))\n",
    "            )\n",
    "            running_loss = 0.0\n",
    "\n",
    "        if steps > max_steps:\n",
    "            break\n",
    "    if steps > max_steps:\n",
    "        break\n",
    "\n",
    "print(\"Finished Training\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
